#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
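description: generate a ShardingSphere config-sharding.yaml from a template yaml and, when -z is given, patch the zookeeper server list and overwrite flag of the server file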
author: justbk2015
date: 2021/5/22
modify_records:
    - 2021/5/22 justbk2015 create this file
"""

import os
import sys
import yaml
import copy
import re
import getopt


class ArgResolve:
    """Parse and validate the command-line options of this config generator.

    Call ``parse()`` first, then ``valid_check()``. Both print the usage text
    and exit the process when the options are unusable.
    """
    # Template files selectable with -t/--temp; index 1 selects the jdbc flavour.
    SHARD_YAML = ['proxy_run_temp.yaml', "jdbc_run_temp.yaml"]
    # Database sharding algorithms selectable with --alg, by index.
    SHARD_ALG = ['default', 'sample', 'mod', 'range']

    def __init__(self, args):
        """Store the raw argument list and initialise every option to its default.

        :param args: command-line arguments without the program name,
            e.g. ``sys.argv[1:]``.
        """
        self.args = args
        self.sharding_num = -1  # -1: derive from the number of -l ips in parse()
        self.sharding_table = {}
        self.database_ips = []
        self.ports = []
        self.is_master = False
        self.temp_sharding_yaml = None
        self.des_dir = None
        self.zookeeper = None
        self.is_build_config = 0
        self.alg_index = 0
        self.is_jdbc = False
        self.thread_range = [0, 0]  # [0, 0] means "keep the template's pool sizes"

    def parse(self):
        """Parse ``self.args`` into the option attributes.

        Exits through the usage text (instead of a traceback) on an unknown or
        malformed option, a malformed --table value, or an out-of-range -t index.
        """
        try:
            opts, _ = getopt.getopt(self.args,
                                    'hn:l:t:m:d:p:z:',
                                    ['help', 'num=', 'table=', 'port=', 'listip=',
                                     'master=', 'dir=', 'temp=',
                                     # --default is a bare flag (see usage), so no '='
                                     'default',
                                     'zoo=', 'build=', 'alg=', 'thread='])
        except getopt.GetoptError as err:
            self._exit_with_usage(1, str(err))
            return
        for opt, arg in opts:
            if opt in ('-h', '--help'):
                print(self.get_usage())
                sys.exit(0)
            elif opt in ('-l', '--listip'):
                # accepts "ip1,ip2" as well as repeated -l; empty pieces are ignored
                self.database_ips.extend(ip for ip in arg.split(',') if ip != '')
            elif opt in ('-n', '--num'):
                self.sharding_num = int(arg)
            elif opt == '--table':
                try:
                    self.sharding_table = self.resolve_sharding_table(arg)
                except ValueError as err:
                    self._exit_with_usage(1, str(err))
            elif opt in ('-m', '--master'):
                self.is_master = int(arg) != 0
            elif opt in ('-d', '--dir'):
                self.des_dir = arg
            elif opt in ('-t', '--temp'):
                yaml_index = int(arg)
                # reject instead of raising IndexError, and never let a negative
                # index silently pick the jdbc template with is_jdbc left False
                if not 0 <= yaml_index < len(self.SHARD_YAML):
                    self._exit_with_usage(
                        2, "-t args must config, support range:[0,{0})".format(len(self.SHARD_YAML)))
                self.is_jdbc = yaml_index == 1
                self.temp_sharding_yaml = self.SHARD_YAML[yaml_index]
            elif opt in ('-p', '--port'):
                self.ports.extend(int(port) for port in arg.split(',') if port != '')
            elif opt == '--build':
                self.is_build_config = int(arg)
            elif opt == '--alg':
                self.alg_index = int(arg)
            elif opt == '--thread':
                tmp_val = arg.split(',')
                # anything other than exactly "min,max" is ignored
                if len(tmp_val) == 2:
                    self.thread_range = [int(tmp_val[0]), int(tmp_val[1])]
            elif opt == '--default':
                # only fills -t/-d if they are still unset at this point;
                # a -t/-d given later on the command line still overrides
                if self.temp_sharding_yaml is None:
                    self.temp_sharding_yaml = self.SHARD_YAML[0]
                if self.des_dir is None:
                    self.des_dir = '.'
            elif opt in ('-z', '--zoo'):
                self.zookeeper = arg
        if self.sharding_num == -1:
            self.sharding_num = len(self.database_ips)

    @classmethod
    def resolve_sharding_table(cls, arg):
        """Turn ``"table:count,table:count"`` into ``{table: count}``.

        Counts are kept as strings; consumers convert them with ``int``.
        Empty entries (e.g. a trailing comma) are skipped.

        :raises ValueError: for an entry without a ``:`` separator.
        """
        if not arg:
            return {}
        result = {}
        for entry in arg.split(','):
            entry = entry.strip()
            if not entry:
                continue
            name, sep, count = entry.partition(':')
            if not sep:
                raise ValueError("invalid --table entry {0!r}, expected table:number".format(entry))
            result[name.strip()] = count.strip()
        return result

    def valid_check(self):
        """Exit with usage and an error code if the parsed options are unusable.

        Exit codes: 1 missing ips / ip or port count mismatch, 2 missing -t,
        3 missing -d, 4 --alg out of range.
        """
        err = None
        if not self.database_ips:
            err = (1, "-l args must config at least one database ip")
        elif len(self.database_ips) != self.sharding_num:
            err = (1, "sharding_num({0}) not equal self.database_ips.len({1})".format(
                self.sharding_num,
                len(self.database_ips)))
        elif self.ports and len(self.ports) != self.sharding_num:
            err = (1, "sharding_num({0}) not equal self.ports.len({1})".format(
                self.sharding_num,
                len(self.ports)))
        elif self.temp_sharding_yaml is None:
            err = (2, "-t args must config, support range:[0,{0})".format(len(self.SHARD_YAML)))
        elif self.des_dir is None:
            err = (3, "-d args must config, use \".\" generate cur path")
        elif not 0 <= self.alg_index < len(self.SHARD_ALG):
            # negative indexes would otherwise silently wrap around SHARD_ALG
            err = (4, "--alg args must in range[0,{0})".format(len(self.SHARD_ALG)))
        if err is not None:
            self._exit_with_usage(*err)

    def _exit_with_usage(self, code, err_detail):
        """Print the usage, the raw arguments and the error, then exit with ``code``."""
        print(self.get_usage())
        print(self.args)
        print('exit with err={0}, message={1}'.format(code, err_detail))
        sys.exit(code)

    @classmethod
    def get_usage(cls):
        """Return the command-line help text."""
        return '''Usage:
-h/--help show this info
-n/--num sharding database number, if not set, use number of -l param
--table sharding mapped table, config like: table:number,table:number. e.g.: order:5,order1:1
-l/--listip [require]
    the database ip, order is very important, use multi to set more, must match -n value
    you can use like "-l ip" or "-l ip1 -l ip2" or "-l ip1,ip2"
-p/--port
    the database ip port, order is very important, it match -l param, default use temp file's port
    you can use like "-p port" or "-p port1 -p port2" or "-p port1,port2"
-m/--master is master proxy node, default is False. 0 -> false, 1-> true, skip if -z param not set
-d/--dir [require] the destination path of proxy, a relative path is resolved against this script's directory
-t/--temp [require]
    the proxy config-sharding.yaml template index
    0->proxy_run_temp.yaml
    1->jdbc_run_temp.yaml
-z/--zoo
    set the zookeeper center, like:20.20.20.88:2181
    if set, will auto change server.yaml
--build default is 0
    whether to generate the build-data config file, 0-> not build, 1-> build (bmsql_item becomes a broadcast table)
--alg default is 0
    use the sharding alg by index, 0->default 1->sample 2->mod 3->range
--thread
    use to set the min and max valid threads, exp: --thread=200,500
--default
    set default temp file index 0 and generate in the script path, like add param: -t 0 -d .

usage:
--->for proxy_test:
python3 main_config_proxy.py -l 20.20.20.54 --default #create config-sharding.yaml in the script path
python3 main_config_proxy.py -l 20.20.20.54 -p 3000 -t 0 -d /home/proxy/sharding_config #create config-sharding.yaml in special path
python3 main_config_proxy.py -l 20.20.20.54 -l 20.20.20.54 -p 4000 -p 4000 -d /home/proxy/sharding_config -t 0 -m 1 -z 20.20.20.54 #all param example
python3 main_config_proxy.py -l 20.20.20.54 -l 20.20.20.54 -p 4000 -p 4000 -d /home/proxy/sharding_config -t 0 -m 1 -z 20.20.20.54
    --table=bmsql_stock:10,bmsql_district:10,bmsql_order_line:10  # --add table split example

--->for proxy_build_data:
python3 main_config_proxy.py -l 20.20.20.54 --default --build 1 # create config file with broadcastTables for item
--->for proxy sample alg
python3 main_config_proxy.py -l 20.20.20.54 --default --alg 1 # create config file using the sample sharding alg
'''


class PathConf:
    """Resolve where the template is read from and where the output is written.

    A relative destination directory is anchored at ``ROOT_PATH`` (the
    directory holding this script), not at the current working directory.
    """
    # Absolute directory of this script file.
    ROOT_PATH = os.path.dirname(os.path.abspath(__file__))

    def __init__(self, temp_yaml, des_dir):
        self.temp_yaml = temp_yaml
        if os.path.isabs(des_dir):
            self.des_dir = des_dir
        else:
            self.des_dir = os.path.join(self.ROOT_PATH, des_dir)

    def get_output_yaml(self):
        """Return the path of the generated ``config-sharding.yaml``."""
        output_name = "config-sharding.yaml"
        return os.path.join(self.des_dir, output_name)

    def get_src_yaml(self):
        """Return the path of the template yaml, looked up next to this script."""
        template_name = self.temp_yaml
        return os.path.join(self.ROOT_PATH, template_name)


class TableConfigFactory:
    """Build datasource, table-rule and sharding-algorithm entries for the config.

    :param db_count: number of sharded databases, i.e. datasources
        ``ds_0`` .. ``ds_{db_count - 1}``.
    :param tb_counts: mapping of table name -> table shard count (``str`` or
        ``int``); tables not listed get a single shard.
    :param alg_index: index into ``ArgResolve.SHARD_ALG`` selecting the
        database sharding algorithm.
    :param is_build_config: truthy to make ``bmsql_item`` a broadcast table
        instead of a sharded one.
    """

    def __init__(self, db_count, tb_counts, alg_index, is_build_config=False):
        self.db_count = db_count
        self.tb_counts = tb_counts
        self.alg_index = alg_index
        self.is_build_config = is_build_config

    def get_tb_count(self, tb_name):
        """Return the table shard count configured for ``tb_name`` (default 1)."""
        return int(self.tb_counts.get(tb_name, 1))

    def init_tables(self):
        """Return ``(table, sharding column)`` pairs of every sharded table.

        ``bmsql_item`` is left out when building the data config, because it
        is emitted as a broadcast table instead.
        """
        tables = [
            ('bmsql_warehouse', 'w_id'),
            ('bmsql_config', 'cfg_id'),
            ('bmsql_district', 'd_w_id'),
            ('bmsql_customer', 'c_w_id'),
            ('bmsql_history', 'h_w_id'),
            ('bmsql_new_order', 'no_w_id'),
            ('bmsql_oorder', 'o_w_id'),
            ('bmsql_order_line', 'ol_w_id'),
            ('bmsql_stock', 's_w_id')
        ]
        if not self.has_broadcast_table():
            tables.append(
                ('bmsql_item', 'i_id')
            )
        return tables

    @classmethod
    def generate_all_datasources(cls, standard_props, database_ips, ports, url):
        """Clone ``standard_props`` once per database ip as ``ds_0``, ``ds_1``, ...

        In each copy the host found in ``standard_props[url]`` (plus its
        ``:port`` when ``ports`` is non-empty) is replaced by ``database_ips[i]``
        (plus ``:ports[i]``).

        :param url: key of the connection url inside ``standard_props``.
        :raises ValueError: if the template url has no recognisable host.
        """
        maps_all = {}
        port_replace = len(ports) != 0
        old_ip = cls.find_cur_ip(standard_props[url], port_replace)
        for i, cur_ip in enumerate(database_ips):
            cur_props = copy.deepcopy(standard_props)
            if port_replace:
                cur_ip += ':%s' % ports[i]
            cur_props[url] = cur_props[url].replace(old_ip, cur_ip)
            maps_all["ds_{}".format(i)] = cur_props
        return maps_all

    @classmethod
    def find_cur_ip(cls, url, port_replace):
        """Return the host part of ``url`` (``localhost`` or a dotted IPv4).

        When ``port_replace`` is true the following ``:port`` is included.

        :raises ValueError: if no such host (or host:port) occurs in ``url``.
        """
        pattern = 'localhost' if 'localhost' in url else r'\.'.join([r'\d{1,3}'] * 4)
        if port_replace:
            pattern += r':\d{1,5}'
        match = re.search(pattern, url)
        if match is None:
            raise ValueError("cannot find host{0} in url {1!r}".format(
                ':port' if port_replace else '', url))
        return match.group(0)

    @classmethod
    def node_expr(cls, name, count):
        """Return the inline node expression ``name_${0..count-1}``.

        A non-datasource name with a single shard is returned unchanged.
        Datasource names (``ds``) always get the range form.
        """
        if count == 1 and not name.startswith("ds"):
            return name
        return '{name}_${left}{begin}..{end}{right}'.format(name=name,
                                                            left='{',
                                                            right='}',
                                                            begin=0,
                                                            end=count - 1)

    @classmethod
    def generate_alg_expr(cls, column, count):
        """Return the inline modulo expression ``column % count``."""
        return '{column} % {count}'.format(column=column,
                                           count=count)

    def generate_alg_name(self, name, is_table):
        """Return the algorithm name ``tb_inline_<name>`` / ``ds_inline_<name>``.

        For a non-default database algorithm the name is ``ds_inline_<alg>``,
        so every table shares one database algorithm definition.
        """
        before = 'tb' if is_table else 'ds'
        if not is_table and self.alg_index != 0:
            name = ArgResolve.SHARD_ALG[self.alg_index]
        return '_'.join([before, "inline", name])

    def build_table(self, ds_expr, tb_name, tb_column):
        """Return the table rule of ``tb_name`` sharded on ``tb_column``.

        A ``tableStrategy`` is only added when the table has more than one shard.
        """
        tb_count = self.get_tb_count(tb_name)
        tb_expr = self.node_expr(tb_name, tb_count)

        values = {"actualDataNodes": '.'.join([ds_expr, tb_expr]), "databaseStrategy": {
            "standard": {
                "shardingColumn": tb_column,
                "shardingAlgorithmName": self.generate_alg_name(tb_name, False)
            }
        }}

        if tb_count > 1:
            values["tableStrategy"] = {
                "standard": {
                    "shardingColumn": tb_column,
                    "shardingAlgorithmName": self.generate_alg_name(tb_name, True)
                }
            }
        return values

    def build_algorithm(self, tb_name, tb_column):
        """Return ``(name, definition)`` pairs for ``tb_name``'s algorithms.

        The database algorithm is always included; the table algorithm only
        when the table has more than one shard.
        """
        all_alg = (
            self._build_algorithm_inner(tb_name, tb_column, False),
            self._build_algorithm_inner(tb_name, tb_column, True),
        )
        return [alg for alg in all_alg if alg[0] is not None]

    def _build_algorithm_inner(self, tb_name, tb_column, is_table):
        """Return ``(name, definition)`` of one algorithm, or ``(None, None)``
        for the table algorithm of a single-shard table."""
        alg_name = self.generate_alg_name(tb_name, is_table)
        alg_count = self.get_tb_count(tb_name) if is_table else self.db_count
        func_alg = ArgResolve.SHARD_ALG[self.alg_index]
        func_args = {}
        if is_table:
            if alg_count == 1:
                return None, None
            # table shards always use divide-before-mod with the database count
            # as divisor, independent of --alg
            func_alg = "div_mod"
            func_args["div"] = self.db_count
        else:
            tb_name = "ds"

        func_name = "generate_alg_def_{0}".format(func_alg)
        alg_def = getattr(self, func_name)(tb_name, tb_column, alg_count, **func_args)
        return alg_name, alg_def

    def has_broadcast_table(self):
        """Return True when broadcast tables must be emitted (build-data mode).

        Uses the same truthiness test as ``init_tables`` so that a table left
        out of the sharded list is always emitted as a broadcast table.
        """
        return bool(self.is_build_config)

    @classmethod
    def build_broadcast_table(cls):
        """Return the tables emitted as broadcast tables in build-data mode."""
        return ['bmsql_item']

    @classmethod
    def generate_alg_def_default(cls, name, column, count, **kwargs):
        """INLINE algorithm routing to ``name_${column % count}``."""
        return {
            "props": {
                "algorithm-expression": "{name}_${left}{express}{right}".format(
                    name=name,
                    express=cls.generate_alg_expr(column, count),
                    left='{',
                    right='}')
            },
            "type": "INLINE"
        }

    @classmethod
    def generate_alg_def_sample(cls, name, column, count, **kwargs):
        """CLASS_BASED algorithm delegating to ``SampleAlg`` over ``count`` shards."""
        return {
            "props": {
                "strategy": "STANDARD",
                "algorithmClassName": "org.apache.shardingsphere.proxy.alg.SampleAlg",
                "sharding-count": count
            },
            "type": "CLASS_BASED"
        }

    @classmethod
    def generate_alg_def_mod(cls, name, column, count, **kwargs):
        """MOD algorithm over ``count`` shards."""
        return {
            "props": {
                "sharding-count": count
            },
            "type": "MOD"
        }

    @classmethod
    def generate_alg_def_range(cls, name, column, count, **kwargs):
        """Range algorithm with 1000-wide ranges starting at 1001.

        Two shards use BOUNDARY_RANGE split at 1001; otherwise VOLUME_RANGE.
        """
        if count == 2:
            return {
                "props": {
                    "sharding-ranges": 1001
                },
                "type": "BOUNDARY_RANGE"
            }
        # NOTE(review): for count == 1 range-lower (1001) exceeds
        # range-upper (1) - confirm this case is valid for VOLUME_RANGE.
        return {
            "props": {
                "range-lower": 1001,
                "range-upper": ((count - 1) * 1000 + 1),
                "sharding-volume": 1000
            },
            "type": "VOLUME_RANGE"
        }

    @classmethod
    def generate_alg_def_div_mod(cls, name, column, count, **kwargs):
        """DIVIDE_BEFORE_MOD algorithm; ``kwargs['div']`` is the divisor."""
        return {
            "props": {
                "sharding-count": count,
                "divide-before-mod": int(kwargs['div'])
            },
            "type": "DIVIDE_BEFORE_MOD"
        }


class BaseFormat:
    """Render the sharding config from a template (proxy flavour by default).

    Subclasses override the ``get_*`` hooks to change the datasource key
    names and the file patched by ``change_server_yaml``.
    """

    def __init__(self, arg_resolve):
        self.arg_resolve = arg_resolve
        # des_dir is resolved here (relative -> next to the script); every file
        # this class reads or writes in the output dir goes through path_conf.
        self.path_conf = PathConf(self.arg_resolve.temp_sharding_yaml, self.arg_resolve.des_dir)
        self.table_factory = None

    def get_url_name(self):
        """Key of the connection url inside a datasource entry."""
        return "url"

    def get_server_file(self):
        """File name (in the output dir) patched by ``change_server_yaml``."""
        return "server.yaml"

    def get_minthread_name(self):
        """Datasource key receiving the --thread minimum."""
        return "minPoolSize"

    def get_maxthread_name(self):
        """Datasource key receiving the --thread maximum."""
        return "maxPoolSize"

    def replace_thread_base(self, standard_prop):
        """Write the --thread min/max into ``standard_prop``.

        Nothing is changed when no valid --thread was given (max is then 0).
        """
        min_t, max_t = self.arg_resolve.thread_range
        if max_t == 0:
            return
        standard_prop[self.get_minthread_name()] = min_t
        standard_prop[self.get_maxthread_name()] = max_t

    def format(self):
        """Build the sharding config from the template, save it and return it."""
        prop = self.load()
        table_factory = TableConfigFactory(self.arg_resolve.sharding_num,
                                           self.arg_resolve.sharding_table,
                                           self.arg_resolve.alg_index,
                                           self.arg_resolve.is_build_config)
        # The template's ds_0 datasource is the prototype for every database;
        # the whole dataSources section is regenerated from it.
        standard_prop = prop.get("dataSources").pop("ds_0")
        self.replace_thread_base(standard_prop)
        prop['dataSources'] = table_factory.generate_all_datasources(standard_prop,
                                                                     self.arg_resolve.database_ips,
                                                                     self.arg_resolve.ports,
                                                                     self.get_url_name())

        rules = prop.get('rules')
        # tolerate templates where these sections are missing or empty (null)
        tables = rules.get('tables') or {}
        algorithms = rules.get('shardingAlgorithms') or {}
        rules['tables'] = tables
        rules['shardingAlgorithms'] = algorithms

        ds_expr = table_factory.node_expr("ds", table_factory.db_count)
        for tb_name, tb_column in table_factory.init_tables():
            tables[tb_name] = table_factory.build_table(ds_expr, tb_name, tb_column)
            for alg_name, alg_def in table_factory.build_algorithm(tb_name, tb_column):
                algorithms[alg_name] = alg_def

        if table_factory.has_broadcast_table():
            broadcast = rules.get("broadcastTables") or []
            for name in table_factory.build_broadcast_table():
                if name not in broadcast:  # avoid duplicates from the template
                    broadcast.append(name)
            rules["broadcastTables"] = broadcast

        self.save(prop)
        return prop

    def save(self, prop):
        """Dump ``prop`` to the output config-sharding.yaml and add the rule tag."""
        file = self.path_conf.get_output_yaml()
        with open(file, "wt", encoding='utf-8') as f:
            yaml.dump(prop, f)

        self.add_rule_tag(file)

    @classmethod
    def add_rule_tag(cls, file):
        """Insert a ``- !SHARDING`` line right after the first ``rules:`` line of ``file``.

        The tag is patched in as text after ``yaml.dump``. The file is left
        untouched when it has no ``rules:`` line.
        """
        with open(file, "rt", encoding='utf-8') as f:
            lines = f.readlines()

        for i, line in enumerate(lines):
            if line.strip() == 'rules:':
                lines.insert(i + 1, '- !SHARDING\n')
                break
        else:
            return  # no rules section: do not append a dangling tag

        with open(file, "wt", encoding='utf-8') as f:
            f.writelines(lines)

    def load(self):
        """Read the template yaml with PyYAML's safe loader.

        ``yaml.load`` without an explicit Loader fails on PyYAML >= 6 and
        would construct arbitrary Python objects on older versions.
        """
        with open(self.path_conf.get_src_yaml(), "rt", encoding='utf-8') as f:
            return yaml.safe_load(f)

    def change_server_yaml(self):
        """Patch the zookeeper server list and the overwrite flag in the server file.

        Only the first line containing ``server-lists`` and the first line
        containing ``overwrite`` are rewritten; all other lines are copied
        unchanged. The file is looked up in the same resolved output
        directory that ``save`` writes to.
        """
        zookeeper = self.arg_resolve.zookeeper
        is_master = self.arg_resolve.is_master

        def server_proc(server_line, begin):
            if zookeeper is None:
                return server_line
            # keep "<indent>server-lists: " (assumes ": " after the key), replace the value;
            # '\n' because text mode already translates newlines on write
            return server_line[:server_line.index(begin) + len(begin) + 2] + zookeeper + '\n'

        def overwrite_proc(overwrite_line, begin):
            # flip the literal true/false on the line so it matches -m
            if is_master and 'false' in overwrite_line:
                return overwrite_line.replace('false', 'true')
            if not is_master and 'true' in overwrite_line:
                return overwrite_line.replace('true', 'false')
            return overwrite_line

        server_yaml = os.path.join(self.path_conf.des_dir, self.get_server_file())
        with open(server_yaml, "rt", encoding='utf-8') as f:
            lines = f.readlines()

        pending = [('server-lists', server_proc), ('overwrite', overwrite_proc)]
        new_lines = []
        for line in lines:
            for idx, (tag, proc) in enumerate(pending):
                if tag in line:
                    line = proc(line, tag)
                    del pending[idx]  # each tag is handled once
                    break
            new_lines.append(line)

        with open(server_yaml, "wt", encoding='utf-8') as f:
            f.writelines(new_lines)


class ProxyFormat(BaseFormat):
    """Proxy flavour of the generator.

    Uses every BaseFormat default unchanged; the constructor is inherited.
    """


class JdbcFormat(BaseFormat):
    """JDBC flavour of the generator (selected by ``-t 1``).

    Only the datasource key names differ from BaseFormat. The file patched by
    ``change_server_yaml`` is the generated ``config-sharding.yaml`` itself.
    The constructor is inherited unchanged.
    """

    def get_maxthread_name(self):
        """Datasource key receiving the --thread maximum."""
        return "maximumPoolSize"

    def get_minthread_name(self):
        """Datasource key receiving the --thread minimum."""
        return "minimumIdle"

    def get_server_file(self):
        """File name (in the output dir) patched by ``change_server_yaml``."""
        return "config-sharding.yaml"

    def get_url_name(self):
        """Key of the connection url inside a datasource entry."""
        return "jdbcUrl"


if __name__ == '__main__':
    arg_resolve = ArgResolve(sys.argv[1:])
    arg_resolve.parse()
    arg_resolve.valid_check()  # exits with usage on invalid options
    # -t 1 selects the jdbc template, which uses different datasource key names
    format_cls = JdbcFormat if arg_resolve.is_jdbc else ProxyFormat
    cur_format = format_cls(arg_resolve)
    cur_format.format()
    # the server file is only patched when a zookeeper center was given (-z)
    if arg_resolve.zookeeper is not None:
        cur_format.change_server_yaml()
    sys.exit(0)
